# -*- coding: utf-8 -*- 
"""
@Author: 孟颖
@email: 652044581@qq.com
@date: 2023/4/20 10:19
@desc: 加密解密模块 (encryption/decryption helpers: hashing, Base64, AES, RSA, SM2, SM4)
"""
import os
from Crypto import Random
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from Crypto.PublicKey import RSA
import base64, hashlib
from Crypto.Cipher import PKCS1_v1_5
from gmssl import sm2, sm4
from enum import Enum


class HashCipher:
    """Hex-digest helpers built on :mod:`hashlib`.

    Every method accepts either a ``str`` (UTF-8 encoded first, exactly as
    ``str.encode()`` does by default) or bytes-like data, and returns the
    lowercase hexadecimal digest string.

    Note: MD5 and SHA-1 are not collision resistant; do not use them for
    password storage or signatures.
    """

    @staticmethod
    def _hexdigest(constructor, message):
        """Hash *message* with the hashlib *constructor* and return the hex digest."""
        # Encode text; pass bytes-like data straight through so callers can hash
        # raw bytes (e.g. file contents) without a lossy decode/encode round trip.
        data = message.encode() if isinstance(message, str) else message
        return constructor(data).hexdigest()

    @staticmethod
    def md5(message):
        """
        Returns the MD5 hash of the message (``str`` or bytes) as hex.
        """
        return HashCipher._hexdigest(hashlib.md5, message)

    @staticmethod
    def sha1(message):
        """
        Returns the SHA1 hash of the message (``str`` or bytes) as hex.
        """
        return HashCipher._hexdigest(hashlib.sha1, message)

    @staticmethod
    def sha256(message):
        """
        Returns the SHA256 hash of the message (``str`` or bytes) as hex.
        """
        return HashCipher._hexdigest(hashlib.sha256, message)

    @staticmethod
    def sha512(message):
        """
        Returns the SHA512 hash of the message (``str`` or bytes) as hex.
        """
        return HashCipher._hexdigest(hashlib.sha512, message)


class Base64Cipher:
    """Conversions between raw bytes and standard Base64 text."""

    @staticmethod
    def bytes_to_base64(bytes_data):
        """
        Encode *bytes_data* with standard Base64 and return the result as ``str``.
        """
        encoded = base64.b64encode(bytes_data)
        return str(encoded, 'utf-8')

    @staticmethod
    def base64_to_bytes(base64_string):
        """
        Decode the Base64 text *base64_string* back into the original ``bytes``.
        """
        ascii_bytes = base64_string.encode('utf-8')
        decoded = base64.b64decode(ascii_bytes)
        return decoded


class AESCipher:
    """AES-CBC encryption with PKCS#7 padding; ciphertext is exchanged as Base64 text.

    AES requires a 16/24/32-byte key and CBC requires a 16-byte IV.

    NOTE(review): reusing one fixed IV for many messages under the same key
    leaks whether messages share a common prefix; prefer a random IV per
    message for real data.
    """

    def __init__(self, key, iv):
        """Store the key and IV as bytes.

        :param key: AES key, ``str`` (UTF-8 encoded) or ``bytes``.
        :param iv: CBC initialisation vector, ``str`` (UTF-8 encoded) or ``bytes``.
        """
        self.key = self._to_bytes(key)
        # Previously only the key was encoded, so a str IV failed later inside
        # AES.new; normalise both arguments the same way.
        self.iv = self._to_bytes(iv)

    @staticmethod
    def _to_bytes(value):
        """Return *value* UTF-8 encoded if it is a ``str``, otherwise unchanged."""
        return value.encode('utf-8') if isinstance(value, str) else value

    def encrypt(self, plaintext):
        """Encrypt *plaintext* (``str``) and return the ciphertext as Base64 text.

        The plaintext is UTF-8 encoded and PKCS#7 padded to the AES block size
        before CBC encryption.
        """
        # A fresh cipher per call so every message starts from the stored IV.
        cipher = AES.new(self.key, AES.MODE_CBC, self.iv)
        padded_plaintext = pad(plaintext.encode('utf-8'), AES.block_size)
        ciphertext = cipher.encrypt(padded_plaintext)
        return base64.b64encode(ciphertext).decode('utf-8')

    def decrypt(self, ciphertext):
        """Decrypt Base64 *ciphertext* produced by :meth:`encrypt` and return ``str``.

        The padding is removed with ``unpad`` and the result decoded as UTF-8.
        """
        cipher = AES.new(self.key, AES.MODE_CBC, self.iv)
        decoded_ciphertext = base64.b64decode(ciphertext)
        padded_plaintext = cipher.decrypt(decoded_ciphertext)
        plaintext = unpad(padded_plaintext, AES.block_size)
        return plaintext.decode('utf-8')


class RSACipher:
    """RSA key-pair generation plus PKCS#1 v1.5 encryption / decryption.

    Keys are stored as PEM files named ``private.pem`` and ``public.pem``;
    ciphertexts are exchanged as Base64 text.

    NOTE(review): PKCS#1 v1.5 encryption is vulnerable to padding-oracle
    (Bleichenbacher) attacks when decryption failures are observable by an
    attacker; consider PKCS1_OAEP for new designs.
    """

    @staticmethod
    def generate(path, bits=2048):
        """Generate an RSA key pair and save it as PEM files under *path*.

        Keys are only generated when neither ``private.pem`` nor ``public.pem``
        exists yet, so existing keys are never overwritten.

        :param path: directory for the key files; created if missing.
        :param bits: modulus size. Defaults to 2048 (the previous hard-coded
            1024 bits is no longer considered secure).
        """
        private_path = os.path.join(path, "private.pem")
        public_path = os.path.join(path, "public.pem")
        if os.path.exists(private_path) or os.path.exists(public_path):
            return
        os.makedirs(path, exist_ok=True)
        rsa = RSA.generate(bits, Random.new().read)
        # Save the private key
        with open(private_path, "wb") as f:
            f.write(rsa.exportKey())
        # Derive and save the public key
        with open(public_path, "wb") as f:
            f.write(rsa.publickey().exportKey())

    @staticmethod
    def _load_key(path, key):
        """Return the key to use: read the PEM file at *path* if given, else *key*.

        *key* may be an already-imported key object or PEM text (``str``/``bytes``).

        :raises ValueError: if neither *path* nor *key* is supplied.
        """
        if path:
            # Context manager so the file handle is closed (it previously leaked).
            with open(path, "rb") as f:
                return RSA.importKey(f.read())
        if key is None:
            raise ValueError("either a key file path or a key must be provided")
        if isinstance(key, (str, bytes)):
            return RSA.importKey(key)
        return key

    @staticmethod
    def encrypt(message: str, public_path=None, public_key=None):
        """Encrypt *message* with the RSA public key and return Base64 text.

        :param public_path: path to a PEM public key; takes precedence over *public_key*.
        :param public_key: imported key object or PEM text.
        :raises ValueError: if no key is supplied.
        """
        key = RSACipher._load_key(public_path, public_key)
        cipher = PKCS1_v1_5.new(key)
        cipher_text = base64.b64encode(cipher.encrypt(message.encode('utf-8')))
        return cipher_text.decode('utf-8')

    @staticmethod
    def decrypt(message: str, private_path=None, private_key=None):
        """Decrypt the Base64 *message* with the RSA private key and return ``str``.

        :param private_path: path to a PEM private key; takes precedence over *private_key*.
        :param private_key: imported key object or PEM text.
        :raises ValueError: if no key is supplied or decryption fails.
        """
        key = RSACipher._load_key(private_path, private_key)
        cipher = PKCS1_v1_5.new(key)
        # Unique sentinel: PKCS1_v1_5.decrypt returns it on a padding failure.
        # The old str sentinel made the following .decode() raise AttributeError.
        failure = object()
        text = cipher.decrypt(base64.b64decode(message.encode('utf-8')), failure)
        if text is failure:
            raise ValueError("解密失败")
        return text.decode('utf-8')


class SM2Cipher:
    """SM2 public-key encryption / decryption through ``gmssl``.

    Ciphertext is exchanged as a hexadecimal string.
    """

    def __init__(self, public_key=None, private_key=None):
        """Create the underlying ``gmssl`` SM2 client.

        This method was previously misspelled ``__int__``, so it never ran
        and ``EncryptClient`` was never set.

        :param public_key: SM2 public key used by :meth:`encrypt`
            (presumably a hex string, as gmssl expects; confirm).
        :param private_key: SM2 private key used by :meth:`decrypt`.
        """
        self.public_key = public_key
        self.private_key = private_key
        self.EncryptClient = sm2.CryptSM2(public_key=self.public_key, private_key=self.private_key)

    def encrypt(self, message: str) -> str:
        """SM2-encrypt the UTF-8 bytes of *message* and return the ciphertext as hex."""
        return self.EncryptClient.encrypt(message.encode('utf-8')).hex()

    def decrypt(self, message: str) -> str:
        """Decrypt a hex ciphertext produced by :meth:`encrypt` and return the text."""
        plaintext = self.EncryptClient.decrypt(bytes.fromhex(message))
        if plaintext is None:
            # NOTE(review): gmssl appears to signal some failures by returning
            # None; raise a clear error instead of AttributeError on .decode().
            raise ValueError("SM2 decryption failed")
        return plaintext.decode('utf-8')


class Sm4Cipher:
    """SM4 block-cipher encryption / decryption through ``gmssl`` (ECB or CBC mode).

    Ciphertext is exchanged as a hexadecimal string.
    """

    class CipherModeEnum(Enum):
        ECB = "ecb"
        CBC = "cbc"

    def __init__(self, key, iv=None, cipher_mode: str = CipherModeEnum.CBC.value):
        """
        :param key: SM4 key, ``str`` (UTF-8 encoded) or ``bytes``.
        :param iv: CBC initialisation vector, ``str`` or ``bytes``. Required in
            CBC mode; optional in ECB mode, where it is not used.
        :param cipher_mode: ``"ecb"`` or ``"cbc"`` (a ``CipherModeEnum`` member is
            also accepted).
        :raises ValueError: unknown mode, or CBC mode without an IV.
        """
        if isinstance(cipher_mode, Sm4Cipher.CipherModeEnum):
            cipher_mode = cipher_mode.value
        valid_modes = {mode.value for mode in Sm4Cipher.CipherModeEnum}
        if cipher_mode not in valid_modes:
            raise ValueError('cipher mode just support cbc or ecb')
        self.cipher_mode = cipher_mode
        # The key is needed in both modes. It used to be set only for CBC, so
        # ECB encrypt/decrypt failed with AttributeError.
        self.key = self._to_bytes(key)
        # iv defaults to None (fine for ECB); the old unconditional
        # iv.encode() crashed on that default.
        self.iv = None if iv is None else self._to_bytes(iv)
        if self.cipher_mode == 'cbc' and self.iv is None:
            raise ValueError('cbc mode requires an iv')

    @staticmethod
    def _to_bytes(value):
        """Return *value* UTF-8 encoded if it is a ``str``, otherwise unchanged."""
        return value.encode() if isinstance(value, str) else value

    def encrypt(self, plaintext: str) -> str:
        """
        :param plaintext: text to encrypt (UTF-8 encoded before encryption)
        :return: ciphertext bytes as a hex string
        """
        enc_data = plaintext.encode()
        crypt_sm4 = sm4.CryptSM4()
        crypt_sm4.set_key(self.key, sm4.SM4_ENCRYPT)
        if self.cipher_mode == 'ecb':
            encrypt_value = crypt_sm4.crypt_ecb(enc_data)
        else:
            encrypt_value = crypt_sm4.crypt_cbc(self.iv, enc_data)
        return encrypt_value.hex()

    def decrypt(self, encrypted_text: str) -> str:
        """
        :param encrypted_text: hex ciphertext produced by :meth:`encrypt`
        :return: the decrypted text, decoded as UTF-8
        """
        crypt_sm4 = sm4.CryptSM4()
        crypt_sm4.set_key(self.key, sm4.SM4_DECRYPT)
        if self.cipher_mode == 'ecb':
            decrypt_value = crypt_sm4.crypt_ecb(bytes.fromhex(encrypted_text))
        else:
            decrypt_value = crypt_sm4.crypt_cbc(self.iv, bytes.fromhex(encrypted_text))
        return decrypt_value.decode(encoding='utf-8')


if __name__ == '__main__':
    import tempfile

    # Hashing with hashlib
    print(HashCipher.md5("asqwdasd"))

    # AES encryption / decryption.
    # Demo only: never hard-code real keys or reuse a fixed IV.
    key = 'mysecretpassword'
    message1 = '032306'
    cipher = AESCipher(key, iv=b'0000000000000000')
    encrypted_message = cipher.encrypt(message1)
    print('Encrypted message:', encrypted_message)
    decrypted_message = cipher.decrypt(encrypted_message)
    print('Decrypted message:', decrypted_message)

    # RSA encryption / decryption. Keys go to a throwaway directory instead of
    # a hard-coded, machine-specific Desktop path, so the demo runs anywhere.
    with tempfile.TemporaryDirectory() as key_dir:
        RSACipher.generate(key_dir)
        s = RSACipher.encrypt('Hello, world!', os.path.join(key_dir, "public.pem"))
        print(RSACipher.decrypt(s, os.path.join(key_dir, "private.pem")))
